import torch
from torch import nn
import argparse
import torch.optim
import os
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import numpy as np
from dataset import cifar100_dataset
from train import train
from test import test
import SKNet
import FcaNet_SKNet
import SKNet_Res2Net

if __name__ == "__main__":
    # Train an FcaNet-SKNet-26 on CIFAR-100 with SGD + cosine annealing,
    # logging to TensorBoard and checkpointing the best-accuracy model.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", default="../data",
                        type=str, help="The input data dir")
    parser.add_argument("--batch_size", default=128,
                        type=int, help="The batch size of training")
    parser.add_argument("--device", default='cuda',
                        type=str, help="The training device")
    parser.add_argument("--learning_rate", default=0.05,
                        type=float, help="learning rate")
    parser.add_argument("--epochs", default=300,
                        type=int, help="Training epoch")
    parser.add_argument("--modeldir", default="./model", type=str)
    # parse_known_args so the script tolerates extra args (e.g. from notebooks)
    args = parser.parse_known_args()[0]

    # Ensure the output directory exists before any writer/checkpoint uses it.
    # BUGFIX: previously created the hard-coded './model' regardless of
    # --modeldir, so a custom model dir was never actually created.
    os.makedirs(args.modeldir, exist_ok=True)

    train_loader, test_loader = cifar100_dataset(args)

    writer = SummaryWriter(os.path.join(args.modeldir, "tensorboard"))
    net = FcaNet_SKNet.SKNet26().to(args.device)
    criterion = nn.CrossEntropyLoss()
    # optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    optimizer = torch.optim.SGD(
        net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=5e-4)
    # optimizer = torch.optim.RMSprop(net.parameters(), lr=args.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=20, last_epoch=-1, eta_min=1e-6)

    lossv, accv = [], []        # per-epoch validation loss / accuracy history
    index_num = 0               # global step counter passed through to train()
    correct_max = 0.            # best validation accuracy seen so far
    PATH = os.path.join(args.modeldir, 'cifar_net.pth')
    for epoch in range(args.epochs):
        # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.01, last_epoch=epoch-1)
        train(net, train_loader, optimizer, criterion,
              writer, args, epoch, index_num)
        # Evaluation does not need gradients.
        with torch.no_grad():
            correct, loss = test(net, test_loader, criterion,
                                 writer, args, epoch, lossv, accv)
        lr_scheduler.step()
        print(optimizer.state_dict()['param_groups'][0]['lr'])
        # Checkpoint only on improvement.
        # BUGFIX: correct_max was never updated, so the checkpoint was
        # rewritten every epoch instead of keeping the best model.
        if correct > correct_max:
            correct_max = correct
            torch.save(net.state_dict(), PATH)

    # Plot and save validation curves (explicit .png so the files open cleanly).
    plt.figure(figsize=(5, 3))
    plt.plot(np.arange(1, args.epochs + 1), lossv)
    plt.title('validation loss')
    plt.savefig(os.path.join(args.modeldir, 'validation_loss.png'))

    plt.figure(figsize=(5, 3))
    plt.plot(np.arange(1, args.epochs + 1), accv)
    plt.title('validation accuracy')
    plt.savefig(os.path.join(args.modeldir, 'validation_accuracy.png'))
